// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2023 werbenhu
// SPDX-FileContributor: werbenhu

package digo

import (
	"encoding/json"
	"errors"
	"fmt"
	"go/ast"
	"go/token"
	"log"
	"os"
	"regexp"
	"sort"
	"strings"

	"golang.org/x/tools/go/packages"
)

// RegexpText is the regular expression that recognises a digo annotation comment.
// Submatch 1 captures the annotation kind (provider, inject or group) and
// submatch 2 captures the text between the parentheses, which is parsed as JSON.
const RegexpText = `^//\s*@(provider|inject|group)\s*\((.*)\s*\)`

// chain is the path of providers currently being walked from a function to its
// transitive dependencies. It is used to detect cyclic dependencies: a provider
// ID that appears twice on the same chain means the provider depends on itself.
type chain []*DiFunc

// newChain returns an empty, non-nil dependency chain ready for insert.
func newChain() chain {
	return chain{}
}

// clone returns an independent copy of the chain with its own backing array,
// so appending to the copy can never affect the receiver.
func (c chain) clone() chain {
	return append(make(chain, 0, len(c)), c...)
}

// String renders the chain as its provider IDs joined by " -> ", for example
// "a -> b -> a", and returns "" for an empty chain. It is used when reporting a
// circular dependency.
func (c chain) String() string {
	// Collect the IDs first and join once; repeated string concatenation in a
	// loop is quadratic in the length of the chain.
	ids := make([]string, len(c))
	for i, fn := range c {
		ids[i] = fn.ProviderId
	}
	return strings.Join(ids, " -> ")
}

// insert appends fn to the chain and reports whether its provider ID was new.
// A false result means the ID was already on the chain, i.e. a cyclic
// dependency. fn is appended in either case, so String shows the closed loop.
func (c *chain) insert(fn *DiFunc) bool {
	duplicate := false
	for _, existing := range *c {
		if existing.ProviderId == fn.ProviderId {
			duplicate = true
			break
		}
	}
	*c = append(*c, fn)
	return !duplicate
}

// Provider holds the decoded JSON body of a @provider annotation,
// e.g. @provider({"id": "..."}).
type Provider struct {
	// Id is the identifier of the provider. encoding/json matches the "id" key
	// to this field case-insensitively, so no struct tag is needed.
	Id string
}

// Member holds the decoded JSON body of a @group annotation,
// e.g. @group({"id": "..."}), which adds the annotated function to that group.
type Member struct {
	GroupId string `json:"id"` // GroupId is the ID of the group the function joins.
}

// Injector holds the decoded JSON body of an @inject annotation, which names the
// provider whose object is passed to one parameter of the annotated function,
// e.g. @inject({"param": "mq", "id": "mq", "pkg": "github.com/mochi-co/mqtt/v2"}).
// Untagged fields are matched case-insensitively by encoding/json ("pkg", "param", "alias").
type Injector struct {
	// ProviderId is the ID of the provider to inject ("id" in the annotation).
	ProviderId string `json:"id"`
	// Pkg is the import path of the parameter's type. It may be given explicitly;
	// otherwise the parser fills it from the imports of the declaring file.
	Pkg        string
	// Param is the name of the function parameter being injected.
	Param      string
	// Alias is the import alias of Pkg, if the declaring file used one.
	Alias      string

	// Typ is the parameter's type expression, taken from the function signature.
	Typ        ast.Expr
	// Dependency is the provider function that ProviderId resolves to; it is
	// filled in once all packages have been parsed.
	Dependency *DiFunc
}

// GetObjName returns the name of the temporary variable, of type any, generated
// for this injector parameter: the parameter name followed by "_obj".
func (i *Injector) GetObjName() string {
	const suffix = "_obj"
	return i.GetArgName() + suffix
}

// GetArgName returns the name of the variable passed to the provider's
// constructor for this injector, which is simply the annotated parameter name.
func (i *Injector) GetArgName() string {
	arg := i.Param
	return arg
}

// replaceSeparator derives a name fragment from a provider or group ID by
// replacing every '.' and every '/' with an underscore.
func replaceSeparator(id string) string {
	return strings.NewReplacer(".", "_", "/", "_").Replace(id)
}

// DiFunc describes a function declaration that carries digo annotations.
type DiFunc struct {
	// Name is the function's identifier in its package.
	Name       string
	// Injectors holds one entry per @inject annotation on the function.
	Injectors  []*Injector
	// ProviderId is the ID from the @provider annotation; empty if there is none.
	ProviderId string
	// GroupId is the ID from the @group annotation; empty if there is none.
	GroupId    string
	// Sort is the registration priority. It is incremented every time the
	// function is reached as a dependency, and funcs are ordered by descending Sort.
	Sort       int
	// Package is the package that declares the function.
	Package    *DiPackage
	// File is the source file that declares the function.
	File       *DiFile
}

// NewDiFunc returns a DiFunc for the function called name, declared in file of
// pkg, with an empty (non-nil) injector list and no annotations recorded yet.
func NewDiFunc(pkg *DiPackage, file *DiFile, name string) *DiFunc {
	fn := new(DiFunc)
	fn.Name = name
	fn.Package = pkg
	fn.File = file
	fn.Injectors = []*Injector{}
	return fn
}

// providerArgName returns the variable name used for this provider's object when
// it is passed to a constructor: the provider ID with '.' and '/' replaced by '_'.
func (fn *DiFunc) providerArgName() string {
	id := fn.ProviderId
	return replaceSeparator(id)
}

// providerObjName returns the name of the temporary variable, of type any, that
// holds the object created by this provider's constructor: providerArgName()
// followed by "_obj".
func (fn *DiFunc) providerObjName() string {
	// Derived from providerArgName so the two names can never drift apart.
	return fn.providerArgName() + "_obj"
}

// providerFuncName returns the name of the generated initialization function
// that registers this provider: "init_" followed by the sanitized provider ID.
func (fn *DiFunc) providerFuncName() string {
	const prefix = "init_"
	return prefix + fn.providerArgName()
}

// groupFuncName returns the name of the generated initialization function that
// registers this function into its group: "group_", the sanitized group ID, "_",
// and the function name. Including the function name keeps the names of several
// members of the same group distinct within a package.
func (fn *DiFunc) groupFuncName() string {
	return "group_" + replaceSeparator(fn.GroupId) + "_" + fn.Name
}

// DiFuncs is a list of annotated functions. It implements sort.Interface,
// ordering functions by descending Sort priority.
type DiFuncs []*DiFunc

// Len reports the number of funcs (sort.Interface).
func (idx DiFuncs) Len() int { return len(idx) }

// Swap exchanges the funcs at indices i and j (sort.Interface).
func (idx DiFuncs) Swap(i, j int) {
	tmp := idx[i]
	idx[i] = idx[j]
	idx[j] = tmp
}

// Less reports whether the func at index i must be ordered before the func at
// index j (sort.Interface). Funcs with a higher Sort priority come first, so
// this is deliberately a "greater than" comparison on Sort.
func (idx DiFuncs) Less(i, j int) bool {
	return idx[i].Sort > idx[j].Sort
}

// Sort orders the funcs in place by descending Sort priority, so that functions
// which others depend on come first.
//
// sort.Stable keeps functions with equal priority in their original (source)
// order. sort.Sort gives no such guarantee, which made the order of generated
// code nondeterministic for ties.
func (idx DiFuncs) Sort() {
	sort.Stable(idx)
}

// DiImport records one import of a Go source file: its import path and, if the
// import spec gave one, the explicit alias.
type DiImport struct {
	Name string // Name is the explicit alias from the import spec; empty if none.
	Path string // Path is the import path, without surrounding quotes.
}

// DiFile describes one Go source file of a parsed package.
type DiFile struct {
	// Name identifies the file; the value is chosen by the caller of NewDiFile.
	Name    string
	// Package is the package the file belongs to.
	Package *DiPackage
	// Imports maps the identifier the file uses for an imported package to
	// that import's details.
	Imports map[string]*DiImport
}

// NewDiFile returns a DiFile called name that belongs to pkg, with an empty
// (non-nil) import table.
func NewDiFile(pkg *DiPackage, name string) *DiFile {
	file := &DiFile{Name: name, Package: pkg}
	file.Imports = map[string]*DiImport{}
	return file
}

// DiPackage describes one parsed Go package and its annotated functions.
type DiPackage struct {
	// Name is the package name from the package clause.
	Name   string
	// Path is the package's import path.
	Path   string
	// Folder is the directory that contains the package's Go files.
	Folder string
	// Funcs are the functions carrying a @provider and/or @group annotation.
	Funcs  DiFuncs
	// Files holds the package's source files, keyed by file name.
	Files  map[string]*DiFile
}

// NewDiPackage returns an empty DiPackage for the package called name, with
// import path path and source directory folder. Funcs and Files start empty
// but non-nil.
func NewDiPackage(name string, path string, folder string) *DiPackage {
	pkg := &DiPackage{Name: name, Path: path, Folder: folder}
	pkg.Funcs = DiFuncs{}
	pkg.Files = map[string]*DiFile{}
	return pkg
}

// findProvider returns the function in this package whose provider ID is id, or
// nil if there is none.
//
// An empty id never matches. Functions that only carry a @group annotation are
// stored in Funcs with an empty ProviderId, so without this guard an empty
// lookup would wrongly return one of them as a "provider".
func (pkg *DiPackage) findProvider(id string) *DiFunc {
	if id == "" {
		return nil
	}
	for _, fn := range pkg.Funcs {
		if fn.ProviderId == id {
			return fn
		}
	}
	return nil
}

// Parser extracts digo annotations from Go packages and drives code generation.
type Parser struct {
	// Packages are the parsed packages that contain at least one annotated function.
	Packages []*DiPackage
}

// NewParser returns a Parser with an empty (non-nil) package list.
func NewParser() *Parser {
	parser := new(Parser)
	parser.Packages = []*DiPackage{}
	return parser
}

// parseImports records the imports declared by decl in file.Imports, keyed by
// the identifier the file uses to refer to each imported package. Declarations
// other than import declarations are ignored; pkg is currently unused.
//
// Keys are chosen as follows:
//   - An explicitly named import (`alias "path"`) is keyed by its alias and
//     always takes precedence.
//   - An unnamed import is keyed by the last path element, as before.
//   - It is also keyed by the package name Go tooling assumes for the path, so
//     that "github.com/mochi-co/mqtt/v2" is found under "mqtt" and
//     "gopkg.in/yaml.v3" under "yaml".
//
// Unnamed imports never overwrite a key that is already present.
func (p *Parser) parseImports(pkg *DiPackage, file *DiFile, decl *ast.GenDecl) {
	if decl.Tok != token.IMPORT {
		return
	}
	for _, spec := range decl.Specs {
		importSpec, ok := spec.(*ast.ImportSpec)
		if !ok {
			continue
		}
		// The path literal may be "interpreted" or `raw`; strip either quote style.
		importPath := strings.Trim(importSpec.Path.Value, "\"`")

		if importSpec.Name != nil && len(importSpec.Name.Name) > 0 {
			file.Imports[importSpec.Name.Name] = &DiImport{
				Name: importSpec.Name.Name,
				Path: importPath,
			}
			continue
		}

		lastElem := importPath[strings.LastIndex(importPath, "/")+1:]
		for _, key := range []string{lastElem, assumedImportName(importPath)} {
			if key == "" {
				continue
			}
			if _, exists := file.Imports[key]; !exists {
				file.Imports[key] = &DiImport{Path: importPath}
			}
		}
	}
}

// assumedImportName returns the package name conventionally assumed for an
// unnamed import of importPath, mirroring the goimports heuristic:
//   - take the last path element, skipping a trailing major-version element
//     such as "v2";
//   - drop a "go-" prefix;
//   - cut at the first character that cannot appear in a Go identifier
//     ("yaml.v3" -> "yaml").
//
// The real package name may still differ.
func assumedImportName(importPath string) string {
	elems := strings.Split(importPath, "/")
	name := elems[len(elems)-1]
	if len(elems) > 1 && isMajorVersionElem(name) {
		name = elems[len(elems)-2]
	}
	name = strings.TrimPrefix(name, "go-")
	if i := strings.IndexFunc(name, func(r rune) bool {
		isIdent := r == '_' || r >= 0x80 ||
			('a' <= r && r <= 'z') || ('A' <= r && r <= 'Z') || ('0' <= r && r <= '9')
		return !isIdent
	}); i >= 0 {
		name = name[:i]
	}
	return name
}

// isMajorVersionElem reports whether elem is a module major-version path
// element such as "v2" or "v10".
func isMajorVersionElem(elem string) bool {
	if len(elem) < 2 || elem[0] != 'v' {
		return false
	}
	for _, c := range elem[1:] {
		if c < '0' || c > '9' {
			return false
		}
	}
	return true
}

// matchComment matches comments that comply with the provider, inject, and group rules.
// It returns the annotation type and the JSON-formatted content of the annotation.
// matchComment匹配符合provider、inject、group规则的注释。
// 返回注解类型和注解的JSON格式内容。
func (p *Parser) matchComment(comment string) (name string, body string) {
	r := regexp.MustCompile(RegexpText)
	if matches := r.FindStringSubmatch(comment); matches != nil {
		name = matches[1]
		body = matches[2]
	}
	return
}

// findProvider searches the already-parsed packages, in order, for the function
// whose provider ID is id. It returns the first match, or nil if none matches.
func (p *Parser) findProvider(id string) *DiFunc {
	var found *DiFunc
	for _, pkg := range p.Packages {
		if found = pkg.findProvider(id); found != nil {
			break
		}
	}
	return found
}

// parseProvider decodes the JSON body of a @provider annotation and stores the
// provider ID in fn.ProviderId.
//
// It returns an error in three cases:
//   - the body is not valid JSON;
//   - the ID is empty;
//   - another function already uses the ID, either in a previously parsed
//     package or in fn's own package.
func (p *Parser) parseProvider(body string, fn *DiFunc) error {
	provider := &Provider{}
	if err := json.Unmarshal([]byte(body), provider); err != nil {
		return fmt.Errorf("wrong JSON format: %w", err)
	}
	// Without an ID the function could never be injected, and an empty ID would
	// collide with group-only functions, which are stored with ProviderId "".
	if provider.Id == "" {
		return errors.New("provider ID must not be empty")
	}

	// p.Packages does not contain fn's package yet, so check it separately.
	if p.findProvider(provider.Id) != nil || fn.Package.findProvider(provider.Id) != nil {
		// No "[ERROR]" prefix here: the caller that finally logs the error adds it.
		return fmt.Errorf("duplicate provider ID: %s", provider.Id)
	}
	fn.ProviderId = provider.Id
	return nil
}

// parseInject decodes the JSON body of one @inject annotation on decl and appends
// the resulting Injector to fn.Injectors.
//
// The body must name a provider ("id") and a parameter of decl ("param"). An
// optional "pkg" sets the import path of the parameter's type explicitly, e.g.
// @inject({"param": "mq", "id": "mq", "pkg": "github.com/mochi-co/mqtt/v2"}).
// Without "pkg", the import path and alias are resolved from the imports of the
// file that declares decl. Supported parameter types are T, *T, pkg.T and *pkg.T.
//
// It returns an error in four cases:
//   - the JSON is malformed;
//   - the provider ID is empty;
//   - the parameter does not exist;
//   - the package of the parameter's type cannot be resolved.
func (p *Parser) parseInject(body string, fn *DiFunc, decl *ast.FuncDecl) error {
	injector := &Injector{}
	if err := json.Unmarshal([]byte(body), injector); err != nil {
		return fmt.Errorf("wrong JSON format: %w", err)
	}
	// Group-only functions are stored with an empty ProviderId, so an empty ID
	// here could later resolve to one of them instead of a real provider.
	if injector.ProviderId == "" {
		return errors.New("provider ID of the injected parameter is empty")
	}

	// Find the annotated parameter in the function signature and record its type.
	for _, field := range decl.Type.Params.List {
		for _, name := range field.Names {
			if name.Name == injector.Param {
				injector.Typ = field.Type
			}
		}
	}
	if injector.Typ == nil {
		return errors.New("injected parameter is not found")
	}

	// An explicit "pkg" takes precedence; the file's imports need not be consulted.
	if len(injector.Pkg) > 0 {
		fn.Injectors = append(fn.Injectors, injector)
		return nil
	}

	// One level of pointer indirection does not change which package declares the type.
	typ := injector.Typ
	if starExpr, ok := typ.(*ast.StarExpr); ok {
		typ = starExpr.X
	}

	isPkgFound := false
	switch t := typ.(type) {
	case *ast.Ident:
		// A predeclared type (int, string, ...) or a type declared in the current
		// package: nothing needs to be imported.
		isPkgFound = true
	case *ast.SelectorExpr:
		// pkg.Type: look the qualifier up in the current file's imports and keep
		// both the import path and any alias. The assertion is checked rather than
		// forced, so an unexpected qualifier is reported instead of panicking.
		if ident, ok := t.X.(*ast.Ident); ok {
			if impor, ok := fn.File.Imports[ident.Name]; ok {
				injector.Pkg = impor.Path
				injector.Alias = impor.Name
				isPkgFound = true
			}
		}
	}

	if !isPkgFound {
		return errors.New("injected parameter's package not found")
	}
	fn.Injectors = append(fn.Injectors, injector)
	return nil
}

// parseGroup decodes the JSON body of a @group annotation and stores the group
// ID in fn.GroupId. It returns an error if the body is not valid JSON or the
// group ID is empty. An empty ID used to be ignored silently, leaving the
// function out of every group.
func (p *Parser) parseGroup(body string, fn *DiFunc) error {
	member := &Member{}
	if err := json.Unmarshal([]byte(body), member); err != nil {
		return fmt.Errorf("wrong JSON format: %w", err)
	}
	if member.GroupId == "" {
		return errors.New("group ID must not be empty")
	}
	fn.GroupId = member.GroupId
	return nil
}

// parseFunc parses every @provider, @inject and @group annotation in the doc
// comment of decl into fn. If fn ends up as a provider or a group member, every
// parameter of decl must have been covered by an @inject annotation; otherwise
// an error naming the first uninjected parameter is returned.
func (p *Parser) parseFunc(pkg *DiPackage, fn *DiFunc, decl *ast.FuncDecl) error {
	var comments []*ast.Comment
	if decl.Doc != nil {
		comments = decl.Doc.List
	}

	for _, c := range comments {
		kind, body := p.matchComment(c.Text)
		switch kind {
		case "provider":
			if err := p.parseProvider(body, fn); err != nil {
				return fmt.Errorf("failed to parse provider annotation, %s in package: %s Func: %s", err.Error(), pkg.Path, fn.Name)
			}
		case "inject":
			if err := p.parseInject(body, fn, decl); err != nil {
				return fmt.Errorf("failed to parse inject annotation, %s in package: %s Func: %s", err.Error(), pkg.Path, fn.Name)
			}
		case "group":
			if err := p.parseGroup(body, fn); err != nil {
				return fmt.Errorf("failed to parse group annotation, %s in package: %s Func: %s", err.Error(), pkg.Path, fn.Name)
			}
		}
	}

	// Functions that are neither providers nor group members need no further checks.
	if fn.ProviderId == "" && fn.GroupId == "" {
		return nil
	}

	injected := make(map[string]bool, len(fn.Injectors))
	for _, inj := range fn.Injectors {
		injected[inj.Param] = true
	}
	for _, field := range decl.Type.Params.List {
		for _, ident := range field.Names {
			if !injected[ident.String()] {
				return fmt.Errorf("all parameters of the provider must be injected, param: %v have not been injected yet, in pkg: %s, function: %s",
					ident.String(), pkg.Path, fn.Name)
			}
		}
	}
	return nil
}

// parse walks the syntax trees of every loaded package. It records the imports
// of each file and parses the annotations of every function declaration.
//
// Filtering rules:
//   - A DiPackage is appended to p.Packages only if at least one of its
//     functions carries a @provider or @group annotation.
//   - A DiFile is stored in DiPackage.Files, keyed by its file path, only if it
//     declares such a function.
//   - Packages without Go files are skipped.
//
// The first annotation error aborts parsing and is returned.
func (p *Parser) parse(pkgs []*packages.Package) error {
	for _, pkg := range pkgs {
		// The package folder is derived from its first Go file. Without any Go
		// file there is nothing to parse, and indexing GoFiles[0] would panic.
		if len(pkg.GoFiles) == 0 {
			continue
		}
		splitted := strings.Split(pkg.GoFiles[0], string(os.PathSeparator))
		folder := strings.Join(splitted[:len(splitted)-1], string(os.PathSeparator))
		diPkg := NewDiPackage(pkg.Name, pkg.PkgPath, folder)

		for _, syntax := range pkg.Syntax {
			// syntax.Name is the identifier of the package clause, not the file
			// name, so resolve the real file path through the package's FileSet.
			fileName := syntax.Name.String()
			if pkg.Fset != nil {
				if name := pkg.Fset.Position(syntax.Pos()).Filename; name != "" {
					fileName = name
				}
			}
			diFile := NewDiFile(diPkg, fileName)
			funcsBefore := len(diPkg.Funcs)

			// Imports precede all other declarations in a Go file, so diFile.Imports
			// is complete before any @inject annotation is resolved against it.
			for _, decl := range syntax.Decls {
				switch d := decl.(type) {
				case *ast.GenDecl:
					p.parseImports(diPkg, diFile, d)
				case *ast.FuncDecl:
					diFunc := NewDiFunc(diPkg, diFile, d.Name.String())
					if err := p.parseFunc(diPkg, diFunc, d); err != nil {
						return err
					}
					if len(diFunc.ProviderId) > 0 || len(diFunc.GroupId) > 0 {
						diPkg.Funcs = append(diPkg.Funcs, diFunc)
					}
				}
			}

			// Keep only files that themselves contributed an annotated function.
			if len(diPkg.Funcs) > funcsBefore {
				diPkg.Files[fileName] = diFile
			}
		}

		if len(diPkg.Funcs) > 0 {
			p.Packages = append(p.Packages, diPkg)
		}
	}
	return nil
}

// findProviderById returns the parsed function whose provider ID is id, or nil
// if no parsed package declares it.
//
// An empty id never matches: group-only functions are stored with an empty
// ProviderId and must not be mistaken for a provider. It is equivalent to
// findProvider otherwise and delegates to it instead of duplicating the search.
func (p *Parser) findProviderById(id string) *DiFunc {
	if id == "" {
		return nil
	}
	return p.findProvider(id)
}

// checkInjectorLegal resolves every injector of every parsed function to the
// provider it names and stores that provider in Injector.Dependency.
// It logs an error and returns false at the first injector whose provider ID is
// not declared by any parsed package; otherwise it returns true.
func (p *Parser) checkInjectorLegal() bool {
	for i := range p.Packages {
		pkg := p.Packages[i]
		for _, fn := range pkg.Funcs {
			for _, injector := range fn.Injectors {
				dependency := p.findProviderById(injector.ProviderId)
				if dependency == nil {
					log.Printf("[ERROR] provider id:%s not found, used in package:%s, func:%s, param:%s",
						injector.ProviderId, pkg.Path, fn.Name, injector.Param)
					return false
				}
				injector.Dependency = dependency
			}
		}
	}
	return true
}

// increaseProviderPrioritys walks the dependencies of fn depth-first along the
// chain c. Each time a provider is reached as a dependency, its Sort priority is
// incremented, so a provider always ends up with a higher Sort than the
// functions that depend on it.
//
// It logs the offending chain and returns false if a provider reappears on the
// chain (a circular dependency); otherwise it returns true.
func (p *Parser) increaseProviderPrioritys(c chain, fn *DiFunc) bool {
	if ok := c.insert(fn); !ok {
		log.Printf("[ERROR] provider circular injection: %s\n", c.String())
		return false
	}

	for _, inj := range fn.Injectors {
		dep := inj.Dependency
		if dep == nil {
			continue
		}
		dep.Sort++
		// Every branch gets its own copy of the path walked so far.
		if !p.increaseProviderPrioritys(c.clone(), dep) {
			return false
		}
	}
	return true
}

// checkCyclicProvider starts a dependency walk from every parsed function to
// detect circular dependencies. As a side effect it raises the Sort priority of
// depended-on providers. After its walks, each package's funcs are sorted by
// priority. It returns false as soon as a cycle is found.
func (p *Parser) checkCyclicProvider() bool {
	for _, pkg := range p.Packages {
		for _, fn := range pkg.Funcs {
			if !p.increaseProviderPrioritys(newChain(), fn) {
				return false
			}
		}
		pkg.Funcs.Sort()
	}
	return true
}

// Start is the entry point of code generation. It proceeds in four steps:
//  1. Load every package matching ./... (relative to the working directory).
//  2. Parse their digo annotations.
//  3. Validate the resulting dependency graph.
//  4. If the graph is valid, run a Generator for each annotated package.
//
// Package loading errors terminate the process with exit status 1. Annotation
// and dependency errors are logged and abort generation.
func (p *Parser) Start() {
	cfg := &packages.Config{Mode: packages.LoadAllSyntax}
	pkgs, err := packages.Load(cfg, "pattern=./...")
	if err != nil {
		fmt.Fprintf(os.Stderr, "load: %v\n", err)
		os.Exit(1)
	}
	// PrintErrors reports every package error on stderr and returns their count.
	if packages.PrintErrors(pkgs) > 0 {
		os.Exit(1)
	}

	if err := p.parse(pkgs); err != nil {
		log.Printf("[ERROR] %s\n", err.Error())
		return
	}

	// Both checks log their own diagnostics. The cycle check only runs once every
	// injector has been resolved, because it follows the resolved dependencies.
	if !p.checkInjectorLegal() || !p.checkCyclicProvider() {
		return
	}
	for _, pkg := range p.Packages {
		generator := NewGenerator(pkg)
		generator.Do()
	}
}
